{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `relay.Span` 示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.relay.testing\n",
    "import tvm.testing\n",
    "from tvm import relay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Span((nullptr), 1, 2, 3, 4)'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a span without a source name. The asserts below pin down the\n",
    "# positional order after it: line, end_line, column, end_column.\n",
    "span = relay.Span(None, 1, 2, 3, 4)\n",
    "# Identity check: `is None` cannot be affected by an overloaded `__eq__`.\n",
    "assert span.source_name is None\n",
    "assert span.line == 1\n",
    "assert span.end_line == 2\n",
    "assert span.column == 3\n",
    "assert span.end_column == 4\n",
    "# A span is the same object as, and compares equal to, itself.\n",
    "assert span.same_as(span)\n",
    "assert span == span\n",
    "assert isinstance(span, relay.base.Span)\n",
    "# Last expression of the cell: displayed as the cell's result.\n",
    "str(span)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip the span through TVM's JSON serialization and confirm that\n",
    "# every field compares equal to the original.\n",
    "back = tvm.ir.load_json(tvm.ir.save_json(span))\n",
    "for field in (\"source_name\", \"line\", \"end_line\", \"column\", \"end_column\"):\n",
    "    assert getattr(back, field) == getattr(span, field)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay import testing\n",
    "import numpy as np\n",
    "from tvm.relay import Expr\n",
    "from tvm.relay.analysis import free_vars\n",
    "\n",
    "def astext(program, unify_free_vars=False):\n",
    "    \"\"\"Return the Relay text of ``program`` after checking that it round-trips.\n",
    "\n",
    "    The text from ``program.astext()`` is parsed back with\n",
    "    ``tvm.relay.parse_expr`` when ``program`` is a ``relay.Expr`` and with\n",
    "    ``tvm.relay.fromtext`` otherwise. The parsed result is then compared with\n",
    "    ``program`` through ``tvm.ir.assert_structural_equal``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    program\n",
    "        Object exposing ``astext()``.\n",
    "    unify_free_vars : bool\n",
    "        Accepted but never read here; the comparison always runs with\n",
    "        ``map_free_vars=True``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The text produced by ``program.astext()``.\n",
    "    \"\"\"\n",
    "    text = program.astext()\n",
    "    # Expressions and whole modules are parsed by different entry points.\n",
    "    parse = tvm.relay.parse_expr if isinstance(program, Expr) else tvm.relay.fromtext\n",
    "    tvm.ir.assert_structural_equal(parse(text), program, map_free_vars=True)\n",
    "    return text\n",
    "\n",
    "def _with_span(call, source_name):\n",
    "    \"\"\"Rebuild ``call`` with a span named ``source_name`` whose positions are all 0.\"\"\"\n",
    "    tag = relay.Span(relay.SourceName(source_name), 0, 0, 0, 0)\n",
    "    return relay.Call(call.op, call.args, call.attrs, call.type_args, tag)\n",
    "\n",
    "\n",
    "x = relay.var(\"x\", shape=(3, 2))\n",
    "y = relay.var(\"y\")  # listed as a parameter of `f` but not used in its body\n",
    "big_const = relay.const(10e10, dtype=\"float32\")\n",
    "\n",
    "# Two nested additions, each re-created with its own named span.\n",
    "add0 = _with_span(relay.add(x, big_const), \"Add0\")\n",
    "add1 = _with_span(relay.add(add0, add0), \"Add1\")\n",
    "\n",
    "f = relay.Function([x, y], add1)\n",
    "# Both source names must show up in the printed text.\n",
    "txt = astext(f)\n",
    "for source_name in (\"Add0\", \"Add1\"):\n",
    "    assert source_name in txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def @main(%data: Tensor[(1, 3, 64, 64), float32] /* ty=Tensor[(1, 3, 64, 64), float32] */, %weight: Tensor[(3, 3, 3, 3), float32] /* ty=Tensor[(3, 3, 3, 3), float32] */, %bn_gamma: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %bn_beta: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %bn_mean: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %bn_var: Tensor[(3), float32] /* ty=Tensor[(3), float32] */) -> Tensor[(1, 3, 64, 64), float32] {\n",
      "  %0 = nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3]) /* ty=Tensor[(1, 3, 64, 64), float32] */;\n",
      "  %1 = nn.batch_norm(%0, %bn_gamma, %bn_beta, %bn_mean, %bn_var) /* ty=(Tensor[(1, 3, 64, 64), float32], Tensor[(3), float32], Tensor[(3), float32]) */;\n",
      "  %1.0 /* ty=Tensor[(1, 3, 64, 64), float32] */\n",
      "}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[16:05:01] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:181: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n"
     ]
    }
   ],
   "source": [
    "# Reference: https://github.com/apache/tvm/blob/main/tests/python/relay/test_pass_annotate_spans_defuse.py\n",
    "data = relay.var(\"data\", relay.TensorType((1, 3, 64, 64), \"float32\"))\n",
    "weight = relay.var(\"weight\")\n",
    "\n",
    "bn_gamma = relay.var(\"bn_gamma\")\n",
    "bn_beta = relay.var(\"bn_beta\")\n",
    "bn_mean = relay.var(\"bn_mean\")\n",
    "bn_var = relay.var(\"bn_var\")\n",
    "\n",
    "# conv2d followed by batch_norm; `[0]` keeps the first element of the\n",
    "# batch_norm result tuple.\n",
    "conv = relay.nn.conv2d(\n",
    "    data=data, weight=weight, kernel_size=(3, 3), channels=3, padding=(1, 1)\n",
    ")\n",
    "bn_out = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mean, bn_var)[0]\n",
    "simple_net = relay.Function(relay.analysis.free_vars(bn_out), bn_out)\n",
    "\n",
    "module, params = relay.testing.create_workload(simple_net)\n",
    "print(module)\n",
    "\n",
    "# Fail with a readable message instead of a bare IndexError when no\n",
    "# target is enabled in this TVM build.\n",
    "enabled = tvm.testing.enabled_targets()\n",
    "assert enabled, \"tvm.testing.enabled_targets() returned no usable target\"\n",
    "target = enabled[0][0]\n",
    "\n",
    "# Apply a few simple passes to legalize the IR.\n",
    "with tvm.transform.PassContext(opt_level=0):\n",
    "    module, params = relay.optimize(module, target=target, params=params)\n",
    "\n",
    "# Run AnnotateSpans followed by DefuseOps on the optimized module.\n",
    "seq = tvm.transform.Sequential(\n",
    "    [relay.transform.AnnotateSpans(), relay.transform.DefuseOps()]\n",
    ")\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    module = seq(module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def @main(%data {virtual_device=VirtualDevice(device_type=1, virtual_device_id=0, target=Target(id=39d9420, kind='llvm', keys={'cpu'}, host=Target(id=396cb40, kind='llvm', keys={'cpu'})))}: Tensor[(1, 3, 64, 64), float32] /* ty=Tensor[(1, 3, 64, 64), float32] span=GeneratedSource:21:11 */, hash=\"145b385fdff2c9c3\", virtual_device=VirtualDevice(device_type=1, virtual_device_id=0, target=Target(id=39d9420, kind='llvm', keys={'cpu'}, host=Target(id=396cb40, kind='llvm', keys={'cpu'})))) -> Tensor[(1, 3, 64, 64), float32] {\n",
      "  %0 = add(meta[relay.Constant][1] /* ty=Tensor[(3), float32] span=GeneratedSource:9:16 */, 1e-05f /* ty=float32 span=GeneratedSource:9:72 */) /* ty=Tensor[(3), float32] span=GeneratedSource:7:5 */;\n",
      "  %1 = rsqrt(%0) /* ty=Tensor[(3), float32] span=GeneratedSource:11:5 */;\n",
      "  %2 = multiply(%1, meta[relay.Constant][2] /* ty=Tensor[(3), float32] span=GeneratedSource:17:20 */) /* ty=Tensor[(3), float32] span=GeneratedSource:15:6 */;\n",
      "  %3 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 3), float32] span=GeneratedSource:21:23 */, padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3]) /* ty=Tensor[(1, 3, 64, 64), float32] span=GeneratedSource:4:5 */;\n",
      "  %4 = expand_dims(%2, axis=1, num_newaxis=2) /* ty=Tensor[(3, 1, 1), float32] span=GeneratedSource:19:5 */;\n",
      "  %5 = negative(meta[relay.Constant][3] /* ty=Tensor[(3), float32] span=GeneratedSource:29:18 */) /* ty=Tensor[(3), float32] span=GeneratedSource:27:5 */;\n",
      "  %6 = multiply(%5, %2) /* ty=Tensor[(3), float32] span=GeneratedSource:31:6 */;\n",
      "  %7 = add(%6, meta[relay.Constant][4] /* ty=Tensor[(3), float32] span=GeneratedSource:37:23 */) /* ty=Tensor[(3), float32] span=GeneratedSource:35:5 */;\n",
      "  %8 = multiply(%3, %4) /* ty=Tensor[(1, 3, 64, 64), float32] span=GeneratedSource:24:6 */;\n",
      "  %9 = expand_dims(%7, axis=1, num_newaxis=2) /* ty=Tensor[(3, 1, 1), float32] span=GeneratedSource:39:5 */;\n",
      "  add(%8, %9) /* ty=Tensor[(1, 3, 64, 64), float32] span=GeneratedSource:44:5 */\n",
      "}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the module after AnnotateSpans and DefuseOps; the IR text now\n",
    "# carries `span=...` annotations.\n",
    "print(module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
